package org.andy.kit.lock.redis;

import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Maps;
import org.andy.kit.lock.DistributedReentrantLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPool;

/**
 * Reentrant distributed lock backed by Redis.<br>
 *
 * <p>Reentrancy is tracked locally, per {@link Thread}. Only a thread's first successful
 * {@link #tryLock(long, TimeUnit)} goes to Redis (through {@code RedisLock}). Nested acquisitions
 * by the same thread just increment a hold count. Only the matching outermost {@link #unlock()}
 * releases the lock in Redis.
 *
 * @author andy
 * @version V1.0
 * @date 18/2/8 12:20 PM
 */
public class RedisReentrantLock implements DistributedReentrantLock {

    private static final Logger LOG = LoggerFactory.getLogger(RedisReentrantLock.class);

    // Lock bookkeeping for every thread that currently holds this lock, keyed by owning thread.
    private final ConcurrentMap<Thread, LockData> threadData = Maps.newConcurrentMap();

    // Only stored here; all Redis access in this class goes through redisLock.
    private final JedisPool jedisPool;
    private final RedisLock redisLock;
    // Key identifying this lock; passed to every RedisLock call.
    private final String lockId;

    /**
     * Creates a reentrant lock identified by {@code lockId}.
     *
     * @param jedisPool pool used to construct the underlying {@code RedisLock}
     * @param lockId identifier of the lock; passed to every {@code RedisLock} call
     */
    public RedisReentrantLock(JedisPool jedisPool, String lockId) {
        this.jedisPool = jedisPool;
        this.lockId = lockId;
        this.redisLock = new RedisLock(jedisPool);
    }

    /**
     * Per-thread lock state: the owning thread, the value obtained from Redis on acquisition, and
     * the number of times the owner has acquired the lock.
     */
    public static class LockData {
        // Thread that owns this entry (stored only; not read anywhere in this class).
        private final Thread currentThread;

        // Value returned by RedisLock.tryRedisLock; handed back to unRedisLock on final release.
        final String localValue;

        // Hold count; starts at 1 because an entry is created on the first successful acquisition.
        final AtomicInteger lockCount = new AtomicInteger(1);

        public LockData(Thread currentThread, String localValue) {
            this.currentThread = currentThread;
            this.localValue = localValue;
        }
    }

    /**
     * Attempts to acquire the lock for the calling thread.
     *
     * <p>If the calling thread already holds the lock, the hold count is incremented and
     * {@code true} is returned immediately, without contacting Redis. Otherwise acquisition is
     * delegated to {@code RedisLock.tryRedisLock(lockId, timeout, unit)}. A non-null value from
     * that call means the lock was obtained.
     *
     * @param timeout timeout forwarded to {@code RedisLock.tryRedisLock} (presumably the maximum
     *     wait; confirm against RedisLock)
     * @param unit time unit of {@code timeout}
     * @return {@code true} if the calling thread now holds the lock, {@code false} otherwise
     */
    @Override
    public boolean tryLock(long timeout, TimeUnit unit) {
        Thread currentThread = Thread.currentThread();
        LockData lockData = threadData.get(currentThread);

        // Already held by this thread: reentrant acquisition, just bump the hold count.
        if (lockData != null) {
            lockData.lockCount.incrementAndGet();
            return true;
        }

        // First acquisition by this thread: go to Redis.
        String lockValue = redisLock.tryRedisLock(lockId, timeout, unit);
        if (lockValue != null) {
            threadData.put(currentThread, new LockData(currentThread, lockValue));
            return true;
        }

        return false;
    }

    /**
     * Releases one hold of the lock by the calling thread.
     *
     * <p>Each call decrements the hold count. Only when the count reaches zero is the lock
     * released in Redis via {@code RedisLock.unRedisLock(lockId, localValue)}. The thread's local
     * entry is removed at that point, even if the Redis release throws.
     *
     * @throws IllegalMonitorStateException if the calling thread does not hold the lock, or if
     *     its hold count drops below zero
     */
    @Override
    public void unlock() {
        Thread currentThread = Thread.currentThread();
        LockData lockData = threadData.get(currentThread);

        if (lockData == null) {
            throw new IllegalMonitorStateException("You do not get the lock: " + lockId);
        }

        // Reentrant release: inner unlocks only decrement the hold count.
        int remainingHolds = lockData.lockCount.decrementAndGet();
        if (remainingHolds > 0) {
            return;
        }

        if (remainingHolds < 0) {
            throw new IllegalMonitorStateException("Lock count < 0 for lockId: " + lockId);
        }

        try {
            // Outermost unlock: release the lock in Redis.
            redisLock.unRedisLock(lockId, lockData.localValue);
        } finally {
            // Drop local ownership even if the Redis release fails.
            threadData.remove(currentThread);
        }
    }

}
